import torch
import torch.nn as nn
from typing import Callable
from spikingjelly.clock_driven.neuron import LIFNode
from spikingjelly.clock_driven import surrogate as surrogate_sj



class LIF(LIFNode):
    """Leaky integrate-and-fire neuron whose leak acts on a detached potential.

    Each charge step computes
    ``v <- v * (1 - 1/tau) + v_reset / tau + x``; the ``v_reset / tau`` term is
    omitted when ``v_reset`` is ``None`` or ``0``. When ``decay_input`` is true
    the input is first divided by ``tau``. The previous potential is detached
    before the leak, so no gradient flows from one step's potential into the
    previous step's potential.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
            v_reset: float = None, surrogate_function: Callable = surrogate_sj.PiecewiseQuadratic(),
            detach_reset: bool = False, cupy_fp32_inference=False, **kwargs):
        # Forward the configuration positionally, in LIFNode's parameter
        # order; anything collected in ``kwargs`` is dropped.
        parent_args = (tau, decay_input, v_threshold, v_reset,
                       surrogate_function, detach_reset, cupy_fp32_inference)
        super().__init__(*parent_args)

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate input ``x`` into the membrane potential ``self.v``.

        Args:
            x: Input current for the current time step.
        """
        drive = x / self.tau if self.decay_input else x
        no_rest_term = self.v_reset is None or self.v_reset == 0

        if type(self.v) is float:
            # Potential not yet materialized as a tensor: start from the input
            # alone, or from the leaked-plus-rest value built from ``v_reset``.
            if no_rest_term:
                self.v = drive
            else:
                self.v = self.v_reset * (1 - 1. / self.tau) + self.v_reset / self.tau + drive
            return

        # Leak the detached previous potential so the graph does not span steps.
        leaked = self.v.detach() * (1 - 1. / self.tau)
        if no_rest_term:
            self.v = leaked + drive
        else:
            self.v = leaked + self.v_reset / self.tau + drive

class LLIF(LIFNode):
    """Long-term dependency enhanced LIF neuron with an input-driven adaptive threshold.

    The membrane potential is charged like a plain LIF neuron with a detached
    leak. Before every step, the per-element firing threshold is raised by
    ``-0.1 * tanh(min(x, 0))``, so only negative inputs move it. It is then
    clamped to ``[1.0, 2.0]``. The threshold update runs under
    ``torch.no_grad()``.

    The threshold tensor is treated as per-sample state:

    * ``reset`` restores it to the scalar ``v_threshold`` passed to the
      constructor.
    * It is re-created whenever the input's shape or device changes.

    This keeps adapted thresholds from leaking between unrelated batches, and
    keeps a differently sized batch from failing to broadcast.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
            v_reset: float = None, surrogate_function: Callable = surrogate_sj.PiecewiseQuadratic(),
            detach_reset: bool = False, cupy_fp32_inference=False,  **kwargs):
        super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, cupy_fp32_inference)
        # Scalar starting threshold, used to (re)build the per-element tensor.
        self.v_threshold_init = v_threshold
        self.ReLU = nn.ReLU()

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate input ``x`` into ``self.v`` using a detached leak.

        Computes ``v * (1 - 1/tau) [+ v_reset / tau] + x``, where the
        ``v_reset`` term is skipped when ``v_reset`` is ``None`` or ``0``.
        The input is divided by ``tau`` first when ``decay_input`` is true.
        """
        if self.decay_input:
            x = x / self.tau
        if self.v_reset is None or self.v_reset == 0:
            if type(self.v) is float:
                self.v = x
            else:
                # Detached: no gradient through the membrane state across steps.
                self.v = self.v.detach() * (1 - 1. / self.tau) + x
        else:
            if type(self.v) is float:
                self.v = self.v_reset * (1 - 1. / self.tau) + self.v_reset / self.tau + x
            else:
                self.v = self.v.detach() * (1 - 1. / self.tau) + self.v_reset / self.tau + x

    def v_threshold_float_to_tensor(self, x: torch.Tensor):
        """Ensure ``self.v_threshold`` is a tensor matching ``x``'s shape and device.

        A fresh tensor filled with the initial scalar threshold is created when
        the threshold is not a tensor yet, or when the stored tensor no longer
        matches ``x``. Examples are a smaller final batch, or the module having
        moved to another device. Otherwise the element-wise update would fail.
        """
        thr = self.v_threshold
        if isinstance(thr, torch.Tensor) and thr.shape == x.shape and thr.device == x.device:
            return
        self.v_threshold = torch.full_like(x.data, self.v_threshold_init)

    def forward(self, x):
        """Adapt the threshold from ``x``, then run one base-class neuron step."""
        self.v_threshold_float_to_tensor(x)
        # x - relu(x) == min(x, 0): keep only the negative part of the input.
        mask = self.ReLU(x)
        mask = x - mask
        self.update_v_threshold(mask)
        return super().forward(x)

    def update_v_threshold(self, mask: torch.Tensor):
        """Raise the threshold where ``mask`` is negative and clamp it to ``[1.0, 2.0]``.

        ``mask`` is expected to be non-positive, so ``-0.1 * tanh(mask)`` is
        non-negative. Done without autograd so the threshold never carries a
        graph.
        """
        with torch.no_grad():
            self.v_threshold = self.v_threshold - 0.1 * torch.tanh(mask)
            self.v_threshold = torch.clamp(self.v_threshold, 1.0, 2.0)

    def reset(self):
        """Reset the neuron's memories and restore the scalar initial threshold."""
        super().reset()
        # The adapted threshold may not be covered by the base reset, so put
        # back the constructor value explicitly; the next forward re-expands it.
        self.v_threshold = self.v_threshold_init

class SLIF(LIFNode):
    """Short-term dependency enhanced LIF neuron with an activity-driven adaptive ``tau``.

    Each step runs charge -> fire -> reset. The per-element time constant is
    then lowered by ``0.1 * tanh(v * spike)``, where ``v`` is the post-reset
    potential, and clamped to ``[1.05, 2]``. A smaller ``tau`` shrinks the
    retained fraction ``1 - 1/tau`` of the potential, i.e. a stronger leak on
    the next step.

    The ``tau`` tensor is treated as per-sample state:

    * ``reset`` restores it to the scalar given to the constructor.
    * It is re-created whenever the input's shape or device changes.
    * It is updated without autograd, so it never holds a graph from an
      earlier batch.

    NOTE(review): the default ``v_reset=1.`` equals the default threshold,
    unlike the ``None`` default used by the other neurons here; confirm this
    is intended.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
            v_reset: float = 1., surrogate_function: Callable = surrogate_sj.PiecewiseQuadratic(),
            detach_reset: bool = False, cupy_fp32_inference=False, **kwargs):
        super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, cupy_fp32_inference)
        # Scalar starting time constant, used to (re)build the per-element tensor.
        self.tau_init = tau

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate input ``x`` into ``self.v`` using a detached leak.

        Computes ``v * (1 - 1/tau) [+ v_reset / tau] + x``, where the
        ``v_reset`` term is skipped when ``v_reset`` is ``None`` or ``0``.
        The input is divided by ``tau`` first when ``decay_input`` is true.
        ``tau`` may be a scalar or a per-element tensor.
        """
        if self.decay_input:
            x = x / self.tau
        if self.v_reset is None or self.v_reset == 0:
            if type(self.v) is float:
                self.v = x
            else:
                self.v = self.v.detach() * (1 - 1. / self.tau) + x
        else:
            if type(self.v) is float:
                self.v = self.v_reset * (1 - 1. / self.tau) + self.v_reset / self.tau + x
            else:
                self.v = self.v.detach() * (1 - 1. / self.tau) + self.v_reset / self.tau + x

    def tau_float_to_tensor(self, x: torch.Tensor):
        """Ensure ``self.tau`` is a tensor matching ``x``'s shape and device.

        A fresh tensor filled with the initial scalar ``tau`` is created when
        ``tau`` is not a tensor yet, or when the stored tensor no longer
        matches ``x`` (e.g. a smaller final batch, or a device move).
        """
        cur = self.tau
        if isinstance(cur, torch.Tensor) and cur.shape == x.shape and cur.device == x.device:
            return
        self.tau = torch.full_like(x.data, self.tau_init)

    def forward(self, x: torch.Tensor):
        """Run one time step through the adaptive-``tau`` update.

        In ``spikingjelly.clock_driven`` the neuron entry point is ``forward``.
        Routing it explicitly to ``single_step_forward`` guarantees that the
        ``tau`` adaptation actually runs. This bypasses the parent class's
        ``forward`` entirely.
        """
        return self.single_step_forward(x)

    def single_step_forward(self, x: torch.Tensor):
        """Charge, fire, reset, then adapt ``tau``; returns the spike tensor."""
        self.v_float_to_tensor(x)
        self.tau_float_to_tensor(x)
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        self.update_tau(self.v, spike)
        return spike

    def update_tau(self, v: torch.Tensor, spike: torch.Tensor):
        """Lower ``tau`` where ``v * spike`` is positive and clamp it to ``[1.05, 2]``.

        Done without autograd, matching the detached membrane leak. Otherwise
        ``self.tau`` would chain a graph across steps and batches, and a later
        backward pass would reach already-freed buffers.
        """
        with torch.no_grad():
            self.tau = self.tau - 0.1 * torch.tanh(v * spike)
            self.tau = torch.clamp(self.tau, 1.05, 2)

    def reset(self):
        """Reset the neuron's memories and restore the scalar initial ``tau``."""
        super().reset()
        # The adapted tau is not a registered memory here, so restore it
        # explicitly; the next forward re-expands it to the input's shape.
        self.tau = self.tau_init